#   Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Explanation class, with visualization functions.
"""
import json
import os
import os.path
import string
from io import open

import numpy as np

# from .exceptions import LimeError
from LIME.exceptions import LimeError
from sklearn.utils import check_random_state


def id_generator(size=15, random_state=None):
    """Generate a random identifier made of uppercase letters and digits.

    Useful for creating unique div ids when embedding HTML into IPython
    notebooks.

    Args:
        size: number of characters in the generated identifier.
        random_state: source of randomness. May be ``None`` (numpy's global
            random generator is used), an integer seed, or any object that
            exposes a numpy-style ``choice(a, size, replace=...)`` method,
            such as ``numpy.random.RandomState``.

    Returns:
        A string of length ``size``.
    """
    if random_state is None:
        # The module-level functions draw from numpy's global RandomState, so
        # the documented default no longer fails on ``None.choice``.
        rng = np.random
    elif isinstance(random_state, (int, np.integer)):
        rng = np.random.RandomState(random_state)
    else:
        rng = random_state
    chars = list(string.ascii_uppercase + string.digits)
    return "".join(rng.choice(chars, size, replace=True))


class DomainMapper(object):
    """Base class that translates explanation features into a concrete domain.

    Each domain (text, tables, images, ...) is expected to provide its own
    subclass. That keeps the Explanation class generic while the
    domain-specific naming and visualization of features lives here.
    """

    def __init__(self):
        """Create a mapper. The base class keeps no state."""

    def map_exp_ids(self, exp, **kwargs):
        """Translate feature ids into concrete feature names.

        The base implementation is the identity: the input is handed back
        untouched. Subclasses may override it.

        Args:
            exp: list of tuples [(id, weight), (id, weight), ...]
            kwargs: optional keyword arguments (unused here)

        Returns:
            list of tuples [(name, weight), (name, weight), ...]
        """
        mapped = exp
        return mapped

    def visualize_instance_html(self, exp, label, div_name, exp_object_name, **kwargs):
        """Build the javascript that renders the explained instance.

        The base implementation renders nothing and yields an empty string.
        Subclasses may override it.

        Args:
            exp: list of tuples [(id, weight), (id, weight), ...]
            label: label id (integer)
            div_name: name of the div object used for rendering (in js)
            exp_object_name: name of the js explanation object
            kwargs: optional keyword arguments (unused here)

        Returns:
            js code for visualizing the instance
        """
        no_markup = ""
        return no_markup


class Explanation(object):
    """Object returned by explainers.

    Holds the per-label explanation data (``local_exp``, ``intercept``,
    ``score``, ``local_pred`` — all initialised empty here and presumably
    filled in by the explainer) and offers list, pyplot and HTML views of it.
    """

    def __init__(self, domain_mapper, mode="classification", class_names=None, random_state=None):
        """Initializer.

        Args:
            domain_mapper: must inherit from DomainMapper class
            mode: "classification" or "regression"
            class_names: list of class names (only used for classification)
            random_state: an integer or numpy.RandomState that will be used to
                generate random numbers. If None, the random state will be
                initialized using the internal numpy seed.

        Raises:
            LimeError: if ``mode`` is neither "classification" nor
                "regression".
        """
        self.random_state = random_state
        self.mode = mode
        self.domain_mapper = domain_mapper
        self.local_exp = {}
        self.intercept = {}
        self.score = {}
        self.local_pred = {}
        if mode == "classification":
            self.class_names = class_names
            self.top_labels = None
            self.predict_proba = None
        elif mode == "regression":
            self.class_names = ["negative", "positive"]
            self.predicted_value = None
            self.min_value = 0.0
            self.max_value = 1.0
            # Regression explanations are stored under this single key.
            self.dummy_label = 1
        else:
            raise LimeError(
                'Invalid explanation mode "{}". ' 'Should be either "classification" ' 'or "regression".'.format(mode)
            )

    @staticmethod
    def _jsonize(x):
        """Serialize ``x`` to JSON that is safe inside an inline <script>.

        ``</`` is emitted as ``<\\/`` (an equivalent JSON/JavaScript string
        escape) so that data such as a feature name containing ``</script>``
        cannot terminate the surrounding script element early.
        """
        return json.dumps(x, ensure_ascii=False).replace("</", "<\\/")

    def available_labels(self):
        """Returns the list of classification labels for which we have any explanations.

        Returns:
            ``top_labels`` as a list when it is set and non-empty, otherwise
            the keys of ``local_exp``.

        Raises:
            NotImplementedError: for regression explanations.
        """
        # Explicit check rather than ``assert``: assertions are stripped when
        # Python runs with -O, which would skip this guard entirely.
        if self.mode != "classification":
            raise NotImplementedError("Not supported for regression explanations.")
        labels = self.top_labels if self.top_labels else self.local_exp.keys()
        return list(labels)

    def as_list(self, label=1, **kwargs):
        """Returns the explanation as a list.

        Args:
            label: desired label. If you ask for a label for which an
                explanation wasn't computed, will throw an exception.
                Will be ignored for regression explanations.
            kwargs: keyword arguments, passed to domain_mapper

        Returns:
            list of tuples (representation, weight), where representation is
            given by domain_mapper. Weight is a float.
        """
        label_to_use = label if self.mode == "classification" else self.dummy_label
        mapped = self.domain_mapper.map_exp_ids(self.local_exp[label_to_use], **kwargs)
        # Force plain Python floats so the result is JSON serializable.
        return [(item[0], float(item[1])) for item in mapped]

    def as_map(self):
        """Returns the map of explanations.

        Returns:
            Map from label to list of tuples (feature_id, weight).
        """
        return self.local_exp

    def as_pyplot_figure(self, label=1, **kwargs):
        """Returns the explanation as a pyplot figure.

        Will throw an error if you don't have matplotlib installed

        Args:
            label: desired label. If you ask for a label for which an
                   explanation wasn't computed, will throw an exception.
                   Will be ignored for regression explanations.
            kwargs: keyword arguments, passed to domain_mapper

        Returns:
            pyplot figure (barchart).
        """
        import matplotlib.pyplot as plt

        exp = self.as_list(label=label, **kwargs)
        fig = plt.figure()
        # Reversed so that the first entry of the explanation ends up as the
        # top bar of the horizontal bar chart.
        vals = [x[1] for x in reversed(exp)]
        names = [x[0] for x in reversed(exp)]
        colors = ["green" if x > 0 else "red" for x in vals]
        pos = np.arange(len(exp)) + 0.5
        plt.barh(pos, vals, align="center", color=colors)
        plt.yticks(pos, names)
        if self.mode == "classification":
            title = "Local explanation for class %s" % self.class_names[label]
        else:
            title = "Local explanation"
        plt.title(title)
        return fig

    def show_in_notebook(self, labels=None, predict_proba=True, show_predicted_value=True, **kwargs):
        """Shows html explanation in ipython notebook.

        See as_html() for parameters.
        This will throw an error if you don't have IPython installed"""

        # IPython.display is the public home of these names; importing
        # ``display`` from IPython.core.display is deprecated.
        from IPython.display import HTML, display

        display(
            HTML(
                self.as_html(
                    labels=labels, predict_proba=predict_proba, show_predicted_value=show_predicted_value, **kwargs
                )
            )
        )

    def save_to_file(self, file_path, labels=None, predict_proba=True, show_predicted_value=True, **kwargs):
        """Saves html explanation to file.

        Params:
            file_path: file to save explanations to

        See as_html() for additional parameters.

        """
        # Render before opening the file: opening in "w" mode truncates it, so
        # a rendering failure must not destroy an existing file.
        html = self.as_html(
            labels=labels, predict_proba=predict_proba, show_predicted_value=show_predicted_value, **kwargs
        )
        with open(file_path, "w", encoding="utf8") as file_:
            file_.write(html)

    def as_html(self, labels=None, predict_proba=True, show_predicted_value=True, **kwargs):
        """Returns the explanation as an html page.

        Args:
            labels: desired labels to show explanations for (as barcharts).
                If you ask for a label for which an explanation wasn't
                computed, will throw an exception. If None, will show
                explanations for all available labels. (only used for classification)
            predict_proba: if true, add  barchart with prediction probabilities
                for the top classes. (only used for classification)
            show_predicted_value: if true, add  barchart with expected value
                (only used for regression)
            kwargs: keyword arguments, passed to domain_mapper

        Returns:
            code for an html page, including javascript includes.
        """
        jsonize = self._jsonize

        if labels is None and self.mode == "classification":
            labels = self.available_labels()

        # The javascript bundle is read from the directory of this module.
        this_dir, _ = os.path.split(__file__)
        with open(os.path.join(this_dir, "bundle.js"), encoding="utf8") as bundle_file:
            bundle = bundle_file.read()

        out = (
            """<html>
        <meta http-equiv="content-type" content="text/html; charset=UTF8">
        <head><script>%s </script></head><body>"""
            % bundle
        )
        # A random suffix keeps the div id unique when several explanations
        # are embedded in the same page.
        random_id = id_generator(size=15, random_state=check_random_state(self.random_state))
        out += (
            """
        <div class="lime top_div" id="top_div%s"></div>
        """
            % random_id
        )

        predict_proba_js = ""
        if self.mode == "classification" and predict_proba:
            predict_proba_js = """
            var pp_div = top_div.append('div')
                                .classed('lime predict_proba', true);
            var pp_svg = pp_div.append('svg').style('width', '100%%');
            var pp = new lime.PredictProba(pp_svg, %s, %s);
            """ % (
                jsonize([str(x) for x in self.class_names]),
                jsonize(list(self.predict_proba.astype(float))),
            )

        predict_value_js = ""
        if self.mode == "regression" and show_predicted_value:
            # reference self.predicted_value
            # (svg, predicted_value, min_value, max_value)
            predict_value_js = """
                    var pp_div = top_div.append('div')
                                        .classed('lime predicted_value', true);
                    var pp_svg = pp_div.append('svg').style('width', '100%%');
                    var pp = new lime.PredictedValue(pp_svg, %s, %s, %s);
                    """ % (
                jsonize(float(self.predicted_value)),
                jsonize(float(self.min_value)),
                jsonize(float(self.max_value)),
            )

        exp_js = """var exp_div;
            var exp = new lime.Explanation(%s);
        """ % (
            jsonize([str(x) for x in self.class_names])
        )

        if self.mode == "classification":
            # One bar chart per requested label.
            for label in labels:
                exp = jsonize(self.as_list(label))
                exp_js += """
                exp_div = top_div.append('div').classed('lime explanation', true);
                exp.show(%s, %d, exp_div);
                """ % (
                    exp,
                    label,
                )
        else:
            exp = jsonize(self.as_list())
            exp_js += """
            exp_div = top_div.append('div').classed('lime explanation', true);
            exp.show(%s, %s, exp_div);
            """ % (
                exp,
                self.dummy_label,
            )

        raw_js = """var raw_div = top_div.append('div');"""

        # The instance itself is visualized for the first label only (or the
        # dummy label in regression mode).
        raw_label = labels[0] if self.mode == "classification" else self.dummy_label
        raw_js += self.domain_mapper.visualize_instance_html(
            self.local_exp[raw_label], raw_label, "raw_div", "exp", **kwargs
        )
        out += """
        <script>
        var top_div = d3.select('#top_div%s').classed('lime top_div', true);
        %s
        %s
        %s
        %s
        </script>
        """ % (
            random_id,
            predict_proba_js,
            predict_value_js,
            exp_js,
            raw_js,
        )
        out += "</body></html>"

        return out
